# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''data conversion'''


from mindspore.mindrecord import FileWriter
import argparse
from mindspore import context
import pandas as pd
import numpy as np
import os
import librosa

def compute_melgram(audio_path, save_path = '', filename = '', save_npy = True):
    """Compute a dB-scaled mel-spectrogram for a single audio file.

    The audio is loaded at 12000 Hz and forced to a fixed length of
    ``int(29.12 * 12000)`` samples. Short clips are zero-padded at the end;
    long clips keep their centre segment. A mel-spectrogram (n_fft=512,
    hop_length=256, n_mels=96) is then computed and passed through
    ``librosa.core.amplitude_to_db``.

    Args:
        audio_path: path of the audio file to load.
        save_path: prefix prepended to the output file name (the script
            passes a directory path ending in '/').
        filename: source file name; its last four characters are replaced
            by '.npy' to form the output name.
        save_npy: when true, the result is also written with ``np.save``.

    Returns:
        The spectrogram with two leading singleton axes inserted.
    """
    sample_rate = 12000
    # 29.12 s of audio; the original author notes this yields 1366 frames.
    target_len = int(29.12 * sample_rate)

    signal, _ = librosa.load(audio_path, sr=sample_rate)
    length = signal.shape[0]
    if length < target_len:
        # Too short: append zeros up to the target length.
        signal = np.hstack((signal, np.zeros((target_len - length,))))
    elif length > target_len:
        # Too long: keep the centred window of target_len samples.
        start = (length - target_len) // 2
        signal = signal[start:(length + target_len) // 2]

    spectrogram = librosa.feature.melspectrogram(
        y=signal, sr=sample_rate, hop_length=256, n_fft=512, n_mels=96)
    result = librosa.core.amplitude_to_db(spectrogram)[np.newaxis, np.newaxis, :]

    if save_npy:
        np.save(save_path + filename[:-4] + '.npy', result)
    return result

def get_data(features_data, labels_data):
    """Pair features with labels as a list of row dictionaries.

    Each row is ``{"id": <position>, "feature": <feature>, "label": <label>}``
    with ids counting from 0. Pairing stops at the shorter input.
    """
    return [
        {"id": position, "feature": feat, "label": lab}
        for position, (lab, feat) in enumerate(zip(labels_data, features_data))
    ]
   

def convert(s):
    """Return ``s`` as an int if it consists only of decimal digits, else unchanged.

    ``str.isdecimal`` is used rather than ``str.isdigit``. The latter also
    accepts characters such as superscript digits ('²'), which ``int()``
    rejects with ValueError. ``isdecimal`` accepts exactly the digit
    characters that ``int()`` can parse.
    """
    return int(s) if s.isdecimal() else s
 
def GetLabel(info_path, info_name, num_classes=50, seed=None):
    """Build multi-hot label files from a tab-separated annotation table.

    Reads ``info_path + '/' + info_name`` and keeps the ``num_classes`` tag
    columns with the highest mean value. Each data row is encoded as
    ``num_classes`` '0'/'1' strings followed by the row's last field. Rows
    are split at random (about 80 % / 20 %) and written to
    ``<info_path>/train2.csv`` and ``<info_path>/val2.csv``.

    Args:
        info_path: directory holding the annotation file; outputs go here too.
        info_name: annotation file name.
        num_classes: number of most frequent tags to keep (default 50, the
            previously hard-coded value).
        seed: optional seed for a reproducible split. ``None`` keeps using
            NumPy's global random state, one draw per row as before.
    """
    with open(info_path + '/' + info_name, 'rb') as f:
        # Fields are tab-separated. The first and last character of each field
        # are dropped (presumably surrounding quotes -- TODO confirm against the
        # annotation file), and all-digit fields become ints.
        T = [[convert(field[1:-1]) for field in line.strip().decode('utf-8').split("\t")]
             for line in f]
    header, records = T[0], T[1:]

    annotation = pd.DataFrame(records, columns=header)
    # [mean, column] for columns[1:-2] (skips the first column and the last
    # two). Sorting is ascending by mean, with ties broken by column name.
    count = sorted([annotation[col].sum() / len(annotation), col]
                   for col in annotation.columns[1:-2])
    # Keep the highest-mean tags. An explicit start index is used because
    # count[-0:] would select every column when num_classes == 0.
    full_label = [name for _, name in count[max(len(count) - num_classes, 0):]]
    # O(1) lookup instead of repeated list.index(). setdefault keeps the first
    # position, matching list.index() if a name were duplicated.
    label_pos = {}
    for pos, name in enumerate(full_label):
        label_pos.setdefault(name, pos)

    out = []
    for record in records:
        encoded = ['0'] * num_classes
        for k, value in enumerate(record):
            if value == 1 and header[k] in label_pos:
                encoded[label_pos[header[k]]] = '1'
        encoded.append(record[-1])
        out.append(encoded)
    out = np.array(out)

    # Random ~80/20 split: one uniform draw per row, in row order.
    rng = np.random if seed is None else np.random.RandomState(seed)
    train, val = [], []
    for row in out:
        if rng.rand() > 0.2:
            train.append(row)
        else:
            val.append(row)
    np.savetxt("{}/train2.csv".format(info_path), np.array(train), fmt='%s', delimiter=',')
    np.savetxt("{}/val2.csv".format(info_path), np.array(val), fmt='%s', delimiter=',')
    
def generator_md(info_name, file_path, num_classes):
    """Load the feature arrays and labels listed in a label CSV.

    ``info_name`` is read without a header row. Its first ``num_classes``
    columns are labels and the last column is an audio path. For every row,
    ``file_path + <path without its last 4 chars> + '.npy'`` is loaded and
    reshaped to (1, 96, 1366).

    Rows whose feature file cannot be used are skipped, as before: missing
    or unreadable file, empty file, wrong element count, or a non-string
    path cell. Unrelated exceptions (e.g. KeyboardInterrupt, programming
    errors) are no longer silently swallowed.

    Returns:
        ``(data, label)``: the stacked features and an int32 label matrix,
        with one entry per successfully loaded row.
    """
    df = pd.read_csv(info_name, header=None)
    df.columns = [str(i) for i in range(num_classes)] + ["mp3_path"]
    # Convert the label block once instead of slicing the frame for every row.
    label_rows = df[df.columns[:-1]].to_numpy()
    data = []
    label = []
    for mp3_path, row_label in zip(df.mp3_path.values, label_rows):
        try:
            feature = np.load(file_path + mp3_path[:-4] + '.npy').reshape(1, 96, 1366)
        except (OSError, ValueError, EOFError, TypeError):
            # OSError: missing/unreadable file; ValueError: wrong size or
            # non-.npy content; EOFError: empty file; TypeError: empty (NaN)
            # path cell that cannot be sliced.
            continue
        data.append(feature)
        label.append(row_label)
    return np.array(data), np.array(label, dtype=np.int32)
    
def convert_to_mindrecord(info_name, file_path, store_path, mr_name, num_classes, num_shard=4):
    """Convert the .npy features listed in a label CSV into a MindRecord dataset.

    Args:
        info_name: label CSV consumed by ``generator_md``.
        file_path: prefix of the .npy feature files.
        store_path: output directory.
        mr_name: base name; ``<store_path>/<mr_name>.mindrecord`` is passed
            to ``FileWriter``.
        num_classes: number of label columns in the CSV.
        num_shard: shard count passed to ``FileWriter`` (default 4, the
            previously hard-coded value).
    """
    data, label = generator_md(info_name, file_path, num_classes)
    # Feature shape must match the reshape performed in generator_md.
    schema_json = {"id": {"type": "int32"},
                   "feature": {"type": "float32", "shape": [1, 96, 1366]},
                   "label": {"type": "int32", "shape": [num_classes]}}

    writer = FileWriter(os.path.join(store_path, '{}.mindrecord'.format(mr_name)), num_shard)
    writer.add_schema(schema_json, "music_tagger_schema")
    writer.add_index(["id"])
    writer.write_raw_data(get_data(data, label))
    writer.commit()
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='get feature')

    parser.add_argument('--get_npy', type=int, help='step 1: extract feature to npy', default=0)
    parser.add_argument('--get_mindrecord', type=int, help='step 2: convert npy to mindrecord', default=1)
    parser.add_argument('--npy_path', type=str, help='path to save/load npy file')
    parser.add_argument('--info_path', type=str, help='path to annotations_final.csv')
    parser.add_argument('--info_name', type=str, help='name info file', default="annotations_final.csv")
    parser.add_argument('--file_path', type=str, help='path to read audio file', default='')
    parser.add_argument('--mr_info', type=str, help='info to guide mindrecord conversion', default='train.csv')
    parser.add_argument('--mr_path', type=str, help='path to save mindrecord', default='')
    parser.add_argument('--mr_name', type=str, help='name of mindrecord', default="set")
    parser.add_argument('--npu', type=int, help='device ID', default=0)
    parser.add_argument('--num_classes', type=int, help='num_classes', default=50)

    args = parser.parse_args()

    if args.get_npy:
        # Step 1: build train/val label files, then one .npy per audio file,
        # mirroring the sub-directory layout of file_path under npy_path.
        GetLabel(args.info_path, args.info_name)
        for sub_dir in os.listdir(args.file_path):
            src_dir = "{}/{}".format(args.file_path, sub_dir)
            if not os.path.isdir(src_dir):
                # Stray files next to the audio sub-folders previously crashed
                # os.listdir below with NotADirectoryError.
                continue
            dst_dir = "{}/{}".format(args.npy_path, sub_dir)
            # makedirs also creates npy_path itself if it does not exist yet.
            os.makedirs(dst_dir, exist_ok=True)
            for audio_name in os.listdir(src_dir):
                compute_melgram("{}/{}".format(src_dir, audio_name), dst_dir + "/", audio_name)

    if args.get_mindrecord:
        # Step 2: pack the .npy features and labels into MindRecord files.
        context.set_context(device_target='Ascend', mode=context.GRAPH_MODE, device_id=args.npu)
        convert_to_mindrecord(args.mr_info, args.npy_path, args.mr_path, args.mr_name, args.num_classes)
